{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 加载模型\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json\n",
    "import time\n",
    "import web\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from config import *\n",
    "from apphelper.image import union_rbox,adjust_box_to_origin,base64_to_PIL\n",
    "from application import trainTicket,idcard \n",
    "if yoloTextFlag =='keras' or AngleModelFlag=='tf' or ocrFlag=='keras':\n",
    "    if GPU:\n",
    "        os.environ[\"CUDA_VISIBLE_DEVICES\"] = str(GPUID)\n",
    "        import tensorflow as tf\n",
    "        from keras import backend as K\n",
    "        config = tf.ConfigProto()\n",
    "        config.gpu_options.allocator_type = 'BFC'\n",
    "        config.gpu_options.per_process_gpu_memory_fraction = 0.3## GPU最大占用量\n",
    "        config.gpu_options.allow_growth = True##GPU是否可动态增加\n",
    "        K.set_session(tf.Session(config=config))\n",
    "        K.get_session().run(tf.global_variables_initializer())\n",
    "    \n",
    "    else:\n",
    "      ##CPU启动\n",
    "      os.environ[\"CUDA_VISIBLE_DEVICES\"] = ''\n",
    "\n",
    "if yoloTextFlag=='opencv':\n",
    "    scale,maxScale = IMGSIZE\n",
    "    from text.opencv_dnn_detect import text_detect\n",
    "elif yoloTextFlag=='darknet':\n",
    "    scale,maxScale = IMGSIZE\n",
    "    from text.darknet_detect import text_detect\n",
    "elif yoloTextFlag=='keras':\n",
    "    scale,maxScale = IMGSIZE[0],2048\n",
    "    from text.keras_detect import  text_detect\n",
    "else:\n",
    "     print( \"err,text engine in keras\\opencv\\darknet\")\n",
    "     \n",
    "from text.opencv_dnn_detect import angle_detect\n",
    "\n",
    "if ocr_redis:\n",
    "    ##多任务并发识别\n",
    "    from helper.redisbase import redisDataBase\n",
    "    ocr = redisDataBase().put_values\n",
    "else:   \n",
    "    from crnn.keys import alphabetChinese,alphabetEnglish\n",
    "    if ocrFlag=='keras':\n",
    "        from crnn.network_keras import CRNN\n",
    "        if chineseModel:\n",
    "            alphabet = alphabetChinese\n",
    "            if LSTMFLAG:\n",
    "                ocrModel = ocrModelKerasLstm\n",
    "            else:\n",
    "                ocrModel = ocrModelKerasDense\n",
    "        else:\n",
    "            ocrModel = ocrModelKerasEng\n",
    "            alphabet = alphabetEnglish\n",
    "            LSTMFLAG = True\n",
    "            \n",
    "    elif ocrFlag=='torch':\n",
    "        from crnn.network_torch import CRNN\n",
    "        if chineseModel:\n",
    "            alphabet = alphabetChinese\n",
    "            if LSTMFLAG:\n",
    "                ocrModel = ocrModelTorchLstm\n",
    "            else:\n",
    "                ocrModel = ocrModelTorchDense\n",
    "                \n",
    "        else:\n",
    "            ocrModel = ocrModelTorchEng\n",
    "            alphabet = alphabetEnglish\n",
    "            LSTMFLAG = True\n",
    "    elif ocrFlag=='opencv':\n",
    "        from crnn.network_dnn import CRNN\n",
    "        ocrModel = ocrModelOpencv\n",
    "        alphabet = alphabetChinese\n",
    "    else:\n",
    "        print( \"err,ocr engine in keras\\opencv\\darknet\")\n",
    "     \n",
    "    nclass = len(alphabet)+1   \n",
    "    if ocrFlag=='opencv':\n",
    "        crnn = CRNN(alphabet=alphabet)\n",
    "    else:\n",
    "        crnn = CRNN( 32, 1, nclass, 256, leakyRelu=False,lstmFlag=LSTMFLAG,GPU=GPU,alphabet=alphabet)\n",
    "    if os.path.exists(ocrModel):\n",
    "        crnn.load_weights(ocrModel)\n",
    "    else:\n",
    "        print(\"download model or tranform model with tools!\")\n",
    "        \n",
    "    ocr = crnn.predict_job\n",
    "    \n",
    "   \n",
    "from main import TextOcrModel\n",
    "\n",
    "model =  TextOcrModel(ocr,text_detect,angle_detect)\n",
    "from apphelper.image import xy_rotate_box,box_rotate,solve\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2\n",
    "import numpy as np\n",
    "def plot_box(img,boxes):\n",
    "    blue = (0, 0, 0) #18\n",
    "    tmp = np.copy(img)\n",
    "    for box in boxes:\n",
    "         cv2.rectangle(tmp, (int(box[0]),int(box[1])), (int(box[2]), int(box[3])), blue, 1) #19\n",
    "    \n",
    "    return Image.fromarray(tmp) \n",
    "\n",
    "def plot_boxes(img,angle, result,color=(0,0,0)):\n",
    "    tmp = np.array(img)\n",
    "    c = color\n",
    "    h,w = img.shape[:2]\n",
    "    thick = int((h + w) / 300)\n",
    "    i = 0\n",
    "    if angle in [90,270]:\n",
    "        imgW,imgH = img.shape[:2]\n",
    "        \n",
    "    else:\n",
    "        imgH,imgW= img.shape[:2]\n",
    "\n",
    "    for line in result:\n",
    "        cx =line['cx']\n",
    "        cy = line['cy']\n",
    "        degree =line['degree']\n",
    "        w  = line['w']\n",
    "        h = line['h']\n",
    "\n",
    "        x1,y1,x2,y2,x3,y3,x4,y4 = xy_rotate_box(cx, cy, w, h, degree/180*np.pi)\n",
    "        \n",
    "        x1,y1,x2,y2,x3,y3,x4,y4 = box_rotate([x1,y1,x2,y2,x3,y3,x4,y4],angle=(360-angle)%360,imgH=imgH,imgW=imgW)\n",
    "        cx  =np.mean([x1,x2,x3,x4])\n",
    "        cy  = np.mean([y1,y2,y3,y4])\n",
    "        cv2.line(tmp,(int(x1),int(y1)),(int(x2),int(y2)),c,1)\n",
    "        cv2.line(tmp,(int(x2),int(y2)),(int(x3),int(y3)),c,1)\n",
    "        cv2.line(tmp,(int(x3),int(y3)),(int(x4),int(y4)),c,1)\n",
    "        cv2.line(tmp,(int(x4),int(y4)),(int(x1),int(y1)),c,1)\n",
    "        mess=str(i)\n",
    "        cv2.putText(tmp, mess, (int(cx), int(cy)),0, 1e-3 * h, c, thick // 2)\n",
    "        i+=1\n",
    "    return Image.fromarray(tmp).convert('RGB')\n",
    "\n",
    "\n",
    "def plot_rboxes(img,boxes,color=(0,0,0)):\n",
    "    tmp = np.array(img)\n",
    "    c = color\n",
    "    h,w = img.shape[:2]\n",
    "    thick = int((h + w) / 300)\n",
    "    i = 0\n",
    "\n",
    "\n",
    "    for box in boxes:\n",
    "\n",
    "        x1,y1,x2,y2,x3,y3,x4,y4 = box\n",
    "        \n",
    "        \n",
    "        cx  =np.mean([x1,x2,x3,x4])\n",
    "        cy  = np.mean([y1,y2,y3,y4])\n",
    "        cv2.line(tmp,(int(x1),int(y1)),(int(x2),int(y2)),c,1)\n",
    "        cv2.line(tmp,(int(x2),int(y2)),(int(x3),int(y3)),c,1)\n",
    "        cv2.line(tmp,(int(x3),int(y3)),(int(x4),int(y4)),c,1)\n",
    "        cv2.line(tmp,(int(x4),int(y4)),(int(x1),int(y1)),c,1)\n",
    "        mess=str(i)\n",
    "        cv2.putText(tmp, mess, (int(cx), int(cy)),0, 1e-3 * h, c, thick // 2)\n",
    "        i+=1\n",
    "    return Image.fromarray(tmp).convert('RGB')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "from PIL import Image\n",
    "p = './test/idcard-demo.jpeg'\n",
    "img = cv2.imread(p)\n",
    "\n",
    "h,w = img.shape[:2]\n",
    "timeTake = time.time()\n",
    "scale=608\n",
    "maxScale=2048\n",
    "\n",
    "result,angle= model.model(img,\n",
    "                                    detectAngle=True,##是否进行文字方向检测\n",
    "                                    scale=scale,\n",
    "                                    maxScale=maxScale,\n",
    "                                    MAX_HORIZONTAL_GAP=80,##字符之间的最大间隔，用于文本行的合并\n",
    "                                    MIN_V_OVERLAPS=0.6,\n",
    "                                    MIN_SIZE_SIM=0.6,\n",
    "                                    TEXT_PROPOSALS_MIN_SCORE=0.1,\n",
    "                                    TEXT_PROPOSALS_NMS_THRESH=0.7,\n",
    "                                    TEXT_LINE_NMS_THRESH = 0.9,##文本行之间测iou值\n",
    "                                     LINE_MIN_SCORE=0.1,                                             \n",
    "                                    leftAdjustAlph=0,##对检测的文本行进行向左延伸\n",
    "                                    rightAdjustAlph=0.1,##对检测的文本行进行向右延伸\n",
    "                                   )\n",
    "        \n",
    "timeTake = time.time()-timeTake\n",
    "\n",
    "print('It take:{}s'.format(timeTake))\n",
    "for line in result:\n",
    "    print(line['text'])\n",
    "plot_boxes(img,angle, result,color=(0,0,0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "boxes,scores  = model.detect_box(img,608,2048)\n",
    "plot_box(img,boxes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "chineseocr",
   "language": "python",
   "name": "chineseocr"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
